{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\software\\Anaconda3\\envs\\tf2.0\\lib\\site-packages\\ipykernel\\parentpoller.py:116: UserWarning: Parent poll failed.  If the frontend dies,\n",
      "                the kernel may be left running.  Please let us know\n",
      "                about your system (bitness, Python, etc.) at\n",
      "                ipython-dev@scipy.org\n",
      "  ipython-dev@scipy.org\"\"\")\n"
     ]
    }
   ],
   "source": [
    "# 初始化\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "import os\n",
    "os.chdir('E:\\GitHub\\QA-abstract-and-reasoning')\n",
    "sys.path.append('E:\\GitHub\\QA-abstract-and-reasoning')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from pgn.layers import BahdanauAttention, Encoder, Pointer\n",
    "from pgn.batcher import batcher\n",
    "from pgn.model import PGN\n",
    "from utils.saveLoader import load_embedding_matrix\n",
    "from utils.saveLoader import Vocab\n",
    "from utils.config import VOCAB_PAD\n",
    "from utils.config_gpu import config_gpu\n",
    "config_gpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run utils/params.py\n",
    "# 产生输入数据\n",
    "vocab = Vocab(VOCAB_PAD)\n",
    "ds =iter(batcher(vocab, params))\n",
    "enc_data, dec_data = next(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "config_gpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n",
      "Building the model ...\n",
      "Creating the vocab ...\n",
      "Creating the checkpoint manager\n",
      "Restored from E:\\GitHub\\QA-abstract-and-reasoning\\data\\checkpoints\\pgn_checkpoints\\ckpt-1\n",
      "Model restored\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "print(\"Building the model ...\")\n",
    "\n",
    "model = PGN(params)\n",
    "\n",
    "print(\"Creating the vocab ...\")\n",
    "vocab = Vocab(params[\"vocab_path\"], params[\"vocab_size\"])\n",
    "# ds = batcher(vocab, params)\n",
    "\n",
    "print(\"Creating the checkpoint manager\")\n",
    "checkpoint = tf.train.Checkpoint(Seq2Seq=model)\n",
    "\n",
    "checkpoint_manager = tf.train.CheckpointManager(checkpoint, PGN_CKPT, max_to_keep=5)\n",
    "\n",
    "# checkpoint_manager = tf.train.CheckpointManager(checkpoint, TEMP_CKPT, max_to_keep=5)\n",
    "# temp_ckpt = os.path.join(TEMP_CKPT, \"ckpt-5\")\n",
    "# checkpoint.restore(temp_ckpt)\n",
    "\n",
    "checkpoint.restore(checkpoint_manager.latest_checkpoint)\n",
    "if checkpoint_manager.latest_checkpoint:\n",
    "    print(\"Restored from {}\".format(checkpoint_manager.latest_checkpoint))\n",
    "else:\n",
    "    print(\"Initializing from scratch.\")\n",
    "print(\"Model restored\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# oov index的decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgn.test_helper import output_id_to_word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32233"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab.count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=233, shape=(4,), dtype=string, numpy=array([b'\\xe5\\x87\\xba\\xe8\\x84\\xb8', b'', b'', b''], dtype=object)>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "article_oovs = enc_data[\"article_oovs\"][3]\n",
    "article_oovs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'出脸'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = 32233\n",
    "output_id_to_word(idx, vocab, article_oovs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "dec_input = tf.constant([vocab.word2id[vocab.START_DECODING]] * 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "dec_input1 = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    dec_input1.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=242, shape=(64,), dtype=int32, numpy=\n",
       "array([32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229,\n",
       "       32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229,\n",
       "       32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229,\n",
       "       32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229,\n",
       "       32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229,\n",
       "       32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229,\n",
       "       32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229, 32229,\n",
       "       32229])>"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dec_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32232"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab.word2id[vocab.UNKNOWN_TOKEN]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = \"123456789\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.index(\"56789\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "32233",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-28-e186ec3e90e9>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mvocab\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mid2word\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m32233\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;31mKeyError\u001b[0m: 32233"
     ]
    }
   ],
   "source": [
    "vocab.id2word[32233]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2.0]",
   "language": "python",
   "name": "conda-env-tf2.0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
